import traceback
import crcmod
import numpy as np
import os

# Cache of crcmod CRC functions keyed by (polynomial, final XOR); building the
# lookup table in crcmod.mkCrcFun costs far more than applying it, and
# calculate_checksum may be called many times.
_CRC_FUN_CACHE = {}

# (polynomial, final XOR) of the CRCs whose lowest bit supplies check bits
# 2, 3 and 4 respectively: CRC-8, CRC-16, CRC-32.
_CHECKSUM_CRCS = (
    (0x107, 0x00),
    (0x11021, 0x0000),
    (0x104C11DB7, 0xFFFFFFFF),
)


def _crc_fun(poly, xor_out):
    """Return a cached crcmod CRC function for ``poly`` with initial value 0."""
    key = (poly, xor_out)
    if key not in _CRC_FUN_CACHE:
        _CRC_FUN_CACHE[key] = crcmod.mkCrcFun(poly, initCrc=0, xorOut=xor_out)
    return _CRC_FUN_CACHE[key]


def calculate_checksum(binary_string, parity_bits_count = 1):
    """Compute up to four check bits for a string of '0'/'1' characters.

    The bits are produced in a fixed order, and each count includes all the
    previous bits:

    1. even parity of the number of '1' characters;
    2. lowest bit of a CRC-8 (poly 0x107) over the packed bytes;
    3. lowest bit of a CRC-16 (poly 0x11021) over the packed bytes;
    4. lowest bit of a CRC-32 (poly 0x104C11DB7, final XOR 0xFFFFFFFF).

    Args:
        binary_string: Bits as a string of '0'/'1' characters (may be empty).
        parity_bits_count: How many check bits to return. 0 returns "", and
            1, 2 or 3 return that many bits. Any other value returns all four.

    Returns:
        The check bits as a string of '0'/'1' characters.
    """
    if parity_bits_count == 0:
        return ""

    # Bit 1: parity of the number of '1's ('0' when even, '1' when odd).
    bits = str(binary_string.count('1') % 2)
    if parity_bits_count == 1:
        return bits

    # Pack the bit string big-endian into whole bytes. `or "0"` keeps an
    # empty string from crashing int(); it then packs to b"".
    data = int(binary_string or "0", 2).to_bytes((len(binary_string) + 7) // 8, byteorder='big')

    # Bits 2..4: lowest bit of each CRC, stopping once enough bits exist.
    for n_bits, (poly, xor_out) in enumerate(_CHECKSUM_CRCS, start=2):
        bits += str(_crc_fun(poly, xor_out)(data) & 1)
        if parity_bits_count == n_bits:
            break
    return bits



def generate_possible_sequences(bitscore, anchor = "10", top_n=5):
    """Propose candidate bit strings by thresholding per-position scores.

    Every score in ``bitscore`` is tried as a threshold. Positions whose
    score is ``>=`` the threshold become '1' and the others '0'. Each
    distinct bit string is scored with ``calculate_score`` (summed
    within-group variance; lower is better), and the ``top_n`` best are
    returned.

    Args:
        bitscore: Mapping from position ``0 .. len(bitscore) - 1`` to a
            numeric score.
        anchor: Bit values assumed for the leading positions; see
            ``calculate_score``.
        top_n: Maximum number of candidates to return.

    Returns:
        List of ``(bit_string, score)`` tuples in ascending score order.
        Ties keep generation order (highest threshold first), so the result
        is reproducible across interpreter runs. A set-based dedup would make
        tie order depend on string-hash randomization.
    """
    n = len(bitscore)

    def calculate_variance(groups):
        # Sum of per-group population variances; an empty group counts as 0
        # (np.var of an empty list would be NaN).
        return sum(np.var(group) if len(group) > 0 else 0 for group in groups)

    def calculate_score(sequence):
        group0 = [bitscore[i] for i, bit in enumerate(sequence) if bit == '0']
        group1 = [bitscore[i] for i, bit in enumerate(sequence) if bit == '1']

        # Add each position's score a second time. The first len(anchor)
        # positions are labelled by the anchor; the rest by the sequence
        # itself.
        # NOTE(review): this counts non-anchor positions twice. If only the
        # anchor positions were meant to be added, restrict this loop to
        # range(len(anchor)). Left unchanged to keep existing rankings.
        for i, bit in enumerate(sequence):
            label = anchor[i] if i < len(anchor) else bit
            (group0 if label == '0' else group1).append(bitscore[i])

        return calculate_variance([group0, group1])

    # One candidate per threshold, deduplicated in insertion order so that
    # iteration (and therefore tie order after the stable sort) is
    # deterministic.
    candidates = dict.fromkeys(
        "".join('1' if bitscore[i] >= threshold else '0' for i in range(n))
        for threshold in sorted(bitscore.values(), reverse=True)
    )

    scored_sequences = [(sequence, calculate_score(sequence)) for sequence in candidates]
    # Stable sort: lowest variance first.
    scored_sequences.sort(key=lambda item: item[1])
    return scored_sequences[:top_n]

def group_scores_by_seeds(seeds, scores):
    """Bucket ``scores`` under the label found at the same index in ``seeds``.

    The two sequences are paired with ``zip``, so extra items in the longer
    one are ignored. Buckets appear in first-seen order, and each keeps its
    scores in input order.

    Returns:
        dict mapping each seed label to the list of its scores.
    """
    buckets = {}
    for label, value in zip(seeds, scores):
        buckets.setdefault(label, []).append(value)
    return buckets

def process_text(text, text_orig, tokenizer, args, key01, usermode):
    """Score ``text`` with the watermark detector under every candidate seed.

    The detector class is chosen from ``args.method_detect``: "openai" or
    ``None`` selects ``OpenaiDetector``, and "maryland" selects
    ``MarylandDetector``. The number of seeds tried is:

    * 256 by default;
    * ``256 ** len(key01)`` with "brick3";
    * ``usermode["wowm"]`` with "wowm";
    * ``usermode.get("maxkey", 1000)`` with "mix".

    The flags are checked in that order, and the last match wins.

    For each seed, the first text's per-token scores are grouped by the
    detector's second return value (used as the position key) and
    aggregated per position. When ``args.payload_max`` is set, each
    aggregated row is reduced to the score of the payload with the smallest
    p-value.

    Args:
        text: Text to score.
        text_orig: Unused; kept for call-site compatibility.
        tokenizer: Tokenizer handed to the detector.
        args: Namespace with the detector and scoring options read below.
        key01: Returned unchanged; its length sizes the "brick3" search.
        usermode: Mode flags and options (supports ``in``, indexing, ``get``).

    Returns:
        ``(mscores, key01)``, where ``mscores[seed][pos]`` holds the
        aggregated score(s). Errors are printed with a traceback (best
        effort), and whatever was collected before the failure is returned.
    """
    import torch  # noqa: F401  -- kept from the original; presumably required by wm
    from wm import OpenaiDetector, MarylandDetector

    # Bound before the `try` so the final return works even when detector
    # construction or scoring fails (this was an UnboundLocalError before).
    mscores = {}
    try:
        if args.method_detect == "openai" or args.method_detect is None:
            detector = OpenaiDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
        elif args.method_detect == "maryland":
            detector = MarylandDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, gamma=args.gamma, delta=args.delta)
        else:
            raise ValueError(f"unsupported method_detect: {args.method_detect!r}")
        detector.args = args
        mpvalues = {}

        # Size of the seed search; later flags override earlier ones.
        maxseed = 256
        if "brick3" in usermode:
            maxseed = 256 ** len(key01)
        if "wowm" in usermode:
            maxseed = usermode["wowm"]
        if "mix" in usermode:
            maxseed = usermode.get("maxkey", 1000)

        for seed in range(maxseed):
            scores_no_aggreg, seeds = detector.get_scores_by_t(
                [text],
                scoring_method=args.scoring_method,
                payload_max=args.payload_max,
                chars_seed=seed,
            )
            # Index 0: the single text passed above.
            seed_scores = group_scores_by_seeds(seeds[0], scores_no_aggreg[0])
            mscores[seed] = {pos: detector.aggregate_scores([seed_scores[pos]]) for pos in seed_scores}
            mpvalues[seed] = {pos: detector.get_pvalues([seed_scores[pos]]) for pos in seed_scores}

        if args.payload_max:
            # For every aggregated row, keep the score of the payload whose
            # p-value is smallest.
            for seed in mscores:
                for pos in mscores[seed]:
                    seed_payloads = np.argmin(mpvalues[seed][pos], axis=1).tolist()
                    mscores[seed][pos] = [float(s[payload]) for s, payload in zip(mscores[seed][pos], seed_payloads)]
    except Exception as e:
        # Deliberate best effort: report and return partial results.
        print(e)
        traceback.print_exc()
    return mscores, key01

import concurrent.futures
import tqdm

def calc_alt2_scores_v1(thres, filter_mscores_list, usermode):
    """Predict a key per item from per-index score tables (older variant).

    For each ``m`` in ``filter_mscores_list``:

    * the gold key is ``m[0][1]``, or ``None`` when ``m[0]`` is ``None``;
    * ``m[1][i][p][0]`` is read as the first score of index ``i`` at
      position ``p`` (positions 0 and 1 are used).

    NOTE(review): presumably ``m[1]`` is the ``mscores`` dict produced by
    ``process_text`` (seed -> {pos: aggregated scores}); confirm against
    the caller.

    Baseline rule:

    * If ``m[1][0]`` does not have exactly two entries, or its position-0
      score is below ``thres``, predict ``None``. The score is that
      position-0 score, or 0.0 when key 0 is absent.
    * Otherwise predict the index with the largest position-1 score, and
      use that maximum as the score.

    Flags in ``usermode`` then adjust the prediction, applied in this order:
    "ckmax", "cmb", "joind".

    Args:
        thres: Threshold, or a tuple ``(thres, thresl, thresu)``. The tuple
            form is required for "ckmax"; otherwise ``thresl``/``thresu``
            are unbound.
        filter_mscores_list: Sequence of items indexed as described above.
        usermode: Mode flags, tested with ``in``. For "cmb"/"joind" it must
            also support ``.get`` (keys "thres" and "joind").

    Returns:
        Tuple ``(keys, predict_keys, scores)`` of equal-length lists.
    """
    # Tuple form also carries the bounds used by "ckmax".
    if type(thres) is tuple:
        thres, thresl, thresu = thres
    keys, predict_keys, scores = [], [], []
    for m in filter_mscores_list:
        # Gold key, or None when there is no gold info.
        keys.append(m[0][1] if m[0] is not None else None)
        pred = None
        # Baseline: reject unless index 0 has exactly two entries and its
        # position-0 score reaches thres.
        if len(m[1][0]) != 2 or m[1][0][0][0] < thres:
            pred = None
            score = m[1][0][0][0] if 0 in m[1][0] else 0.0
        else:
            # Predict the index whose position-1 score is largest.
            s = [m[1][i][1][0] for i in range(len(m[1]))]
            pred = int(np.argmax(s))
            score = np.max(s)
        if "ckmax" in usermode:
            # Re-score with the best position-1 score (-1 when index 0 lacks
            # two entries). Drop weak positives below thresl; force
            # acceptance above thresu.
            s = [m[1][i][1][0] for i in range(len(m[1]))] if len(m[1][0]) == 2 else -1
            score = float(np.max(s))

            if 0 < score < thresl:
                pred = None
            if score > thresu:
                pred = int(np.argmax(s))
        if "cmb" in usermode:
            # Accept the best index when its position-1 score exceeds
            # usermode["thres"] (default 1.5), even if the baseline rejected.
            if len(m[1][0]) == 2:
                s = [m[1][i][1][0] for i in range(len(m[1]))]
                if np.max(s) > usermode.get("thres", 1.5):
                    pred = int(np.argmax(s))
        if "joind" in usermode:
            # Decide from the best position-0 score among the first N indices,
            # where N = usermode["joind"] if truthy else 3. Threshold is
            # usermode["thres"] (default 1.3). This overrides earlier flags.
            if len(m[1][0]) != 2:
                pred = None
            else:
                maxjoincnt = usermode.get("joind") or 3
                ms = max([m[1][mi][0][0] for mi in range(maxjoincnt)])
                if ms < usermode.get("thres", 1.3):
                    pred = None
                else:
                    s = [m[1][i][1][0] for i in range(len(m[1]))]
                    pred = int(np.argmax(s))
        # The score is not changed by "cmb"/"joind".
        predict_keys.append(pred)
        scores.append(score)
    return keys, predict_keys, scores

def calc_alt2_scores(params, two_scores, usermode):
    """Turn per-text detection summaries into (gold, prediction, score) lists.

    Each record in ``two_scores`` is a mapping read through the keys
    ``"gold"``, ``"worn_score"``, ``"max_id"`` and ``"max_id_score"``. Some
    modes also read ``"id_mean_score"`` and ``"id_second_score"``.

    Baseline decision:

    * When ``worn_score`` is below the threshold or ``max_id`` is negative,
      predict ``None`` and report ``worn_score``.
    * Otherwise predict ``max_id`` and report ``max_id_score``.

    Flags in ``usermode`` then refine the decision, in this fixed order (a
    later flag may override an earlier one):

    * ``"ckand"``: reject when the (adjusted) id score is below ``thres1``;
      the reported score becomes that id score.
    * ``"ckcmp"``: reject when ``max_id_score - id_mean_score`` is below
      ``thres1 - 4.0``.
    * ``"ckmax"``: reject when ``0 < id score < thresl``; accept ``max_id``
      when ``id score > thresu``; reports the id score.
    * ``"ckadd"``: accept iff both scores are non-negative and
      ``sigmoid(worn_score) + sigmoid(max_id_score) * weight > thres``.
    * ``"onlyid"``: accept iff the (adjusted) id score exceeds ``thres``;
      reports the id score.

    Modifiers:

    * ``"mm"`` subtracts ``id_mean_score`` from the id score (ckand, ckmax,
      onlyid).
    * ``"msnd"`` subtracts ``id_second_score`` (ckand, onlyid). For ckand it
      also scales ``thres1`` by 0.1.

    Args:
        params: The threshold, or a tuple whose first element is the
            threshold. "ckand"/"ckcmp" unpack ``(thres, thres1)``, "ckmax"
            unpacks ``(thres, thresl, thresu)``, and "ckadd" unpacks
            ``(thres, weight)``.
        two_scores: Iterable of record mappings as described above.
        usermode: Container of mode flags, tested with ``in``.

    Returns:
        ``(keys, predict_keys, scores)``: three equal-length lists.
    """
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def id_score(rec, use_second):
        # max_id_score, optionally made relative to the mean ("mm") and, where
        # use_second is set, to the runner-up ("msnd") id score.
        value = rec["max_id_score"]
        if "mm" in usermode:
            value -= rec["id_mean_score"]
        if use_second and "msnd" in usermode:
            value -= rec["id_second_score"]
        return value

    gold_keys, guesses, reported = [], [], []
    thres = params[0] if type(params) is tuple else params

    for rec in two_scores:
        gold_keys.append(rec["gold"])

        if rec["worn_score"] < thres or rec["max_id"] < 0:
            guess, value = None, rec["worn_score"]
        else:
            guess, value = rec["max_id"], rec["max_id_score"]

        if "ckand" in usermode:
            thres, thres1 = params
            value = id_score(rec, True)
            if "msnd" in usermode:
                thres1 *= 0.1
            if value < thres1:
                guess = None

        if "ckcmp" in usermode:
            thres, thres1 = params
            if rec["max_id_score"] - rec["id_mean_score"] < thres1 - 4.0:
                guess = None

        if "ckmax" in usermode:
            thres, thresl, thresu = params
            value = id_score(rec, False)
            if 0 < value < thresl:
                guess = None
            if value > thresu:
                guess = rec["max_id"]

        if "ckadd" in usermode:
            thres, weight = params
            usable = not (rec["worn_score"] < 0 or rec["max_id"] < 0)
            combined_ok = usable and sigmoid(rec["worn_score"]) + sigmoid(rec["max_id_score"]) * weight > thres
            guess = rec["max_id"] if combined_ok else None

        if "onlyid" in usermode:
            value = id_score(rec, True)
            guess = rec["max_id"] if value > thres else None

        guesses.append(guess)
        reported.append(value)

    return gold_keys, guesses, reported



def process_thresholds(thres_combinations, dev_list, usermode, score_func):
    """Measure ``score_func`` accuracy for each threshold combination.

    Args:
        thres_combinations: Iterable of hashable threshold tuples, e.g.
            ``(thres, thresl, thresu)``. Each is passed to ``score_func`` as
            a tuple of any length, so 2-parameter modes also work.
        dev_list: Development data forwarded to ``score_func``.
        usermode: Mode flags forwarded to ``score_func``.
        score_func: Callable ``(params, dev_list, usermode)`` returning
            ``(keys, predict_keys, scores)``.

    Returns:
        dict mapping each combination to its accuracy: the fraction of
        positions where the prediction equals the gold key. The accuracy is
        0.0 when ``score_func`` returns no keys, instead of raising
        ZeroDivisionError.
    """
    results = {}
    for thres_combo in thres_combinations:
        keys, predict_keys, _ = score_func(tuple(thres_combo), dev_list, usermode)
        hits = sum(gold == pred for gold, pred in zip(keys, predict_keys))
        results[thres_combo] = hits / len(keys) if len(keys) else 0.0
    return results

def parallel_tune_dev(threses, dev_list, usermode, score_func, num_blocks=90):
    """Grid-search ``(thres, thresl, thresu)`` over ``threses`` in parallel.

    Every triple in the Cartesian product ``threses x threses x threses`` is
    scored by ``process_thresholds`` in a process pool. ``dev_list``,
    ``usermode`` and ``score_func`` are sent to worker processes, so they
    must be picklable (e.g. ``score_func`` must be a module-level function).

    Args:
        threses: Candidate threshold values used for all three slots.
        dev_list: Development data forwarded to ``score_func``.
        usermode: Mode flags forwarded to ``score_func``.
        score_func: Callable ``(params, dev_list, usermode)`` returning
            ``(keys, predict_keys, scores)``.
        num_blocks: Number of chunks the grid is split into (one task per
            chunk). It is clamped to ``[1, number of combinations]`` so no
            empty tasks are submitted.

    Returns:
        dict mapping each triple to its dev-set accuracy. Insertion order
        follows task completion, not grid order.
    """
    thres_combinations = [(thres, thresl, thresu) for thres in threses for thresl in threses for thresu in threses]
    if not thres_combinations:
        return {}

    # Split the combinations into num_blocks contiguous, near-equal chunks.
    num_blocks = max(1, min(num_blocks, len(thres_combinations)))
    k, m = divmod(len(thres_combinations), num_blocks)
    blocks = [thres_combinations[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(num_blocks)]

    # Leave one core free and cap at 47 workers. os.cpu_count() may return
    # None, and "cpus - 1" is 0 on a single-CPU machine, which
    # ProcessPoolExecutor rejects; always keep at least one worker.
    max_workers = max(1, min(os.cpu_count() or 1, 48) - 1)

    results = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_thresholds, block, dev_list, usermode, score_func) for block in blocks]
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            results.update(future.result())

    return results